{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建多层的 LSTM 网络实现 MNIST 分类\n",
    "\n",
    "通过本例，你可以了解到单层 LSTM 的实现，多层 LSTM 的实现。输入输出数据的格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "(55000, 784)\n"
     ]
    }
   ],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.contrib import rnn\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# 设置 GPU 按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "# 首先导入数据，看一下数据的形式\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "print mnist.train.images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 一、首先设置好模型用到的各个超参数 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "input_size = 28      # 每个时刻的输入特征是28维的，就是每个时刻输入一行，一行有 28 个像素\n",
    "timestep_size = 28   # 时序持续长度为28，即每做一次预测，需要先输入28行\n",
    "hidden_size = 256    # 隐含层的数量\n",
    "layer_num = 2        # LSTM layer 的层数\n",
    "class_num = 10       # 最后输出分类类别数量，如果是回归预测的话应该是 1\n",
    "\n",
    "_X = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, class_num])\n",
    "# 在训练和测试的时候，我们想用不同的 batch_size.所以采用占位符的方式\n",
    "batch_size = tf.placeholder(tf.int32, [])  # 注意类型必须为 tf.int32, batch_size = 128\n",
    "keep_prob = tf.placeholder(tf.float32, [])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 二、开始搭建 LSTM 模型，其实普通 RNNs 模型也一样 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 把784个点的字符信息还原成 28 * 28 的图片\n",
    "# 下面几个步骤是实现 RNN / LSTM 的关键\n",
    "####################################################################\n",
    "# **步骤1：RNN 的输入shape = (batch_size, timestep_size, input_size) \n",
    "X = tf.reshape(_X, [-1, 28, 28])\n",
    "\n",
    "# # **步骤2：定义一层 LSTM_cell，只需要说明 hidden_size, 它会自动匹配输入的 X 的维度\n",
    "# lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)\n",
    "\n",
    "# # **步骤3：添加 dropout layer, 一般只设置 output_keep_prob\n",
    "# lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)\n",
    "\n",
    "# # **步骤4：调用 MultiRNNCell 来实现多层 LSTM\n",
    "# mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)\n",
    "# mlstm_cell = rnn.MultiRNNCell([lstm_cell for _ in range(layer_num)] , state_is_tuple=True)\n",
    "\n",
    "# 在 tf 1.0.0 版本中，可以使用上面的 三个步骤创建多层 lstm， 但是在 tf 1.2.1 版本中，可以通过下面方式来创建\n",
    "def lstm_cell():\n",
    "    cell = rnn.LSTMCell(hidden_size, reuse=tf.get_variable_scope().reuse)\n",
    "    return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)\n",
    "\n",
    "mlstm_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)], state_is_tuple = True)\n",
    "\n",
    "# **步骤5：用全零来初始化state\n",
    "init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)\n",
    "\n",
    "# **步骤6：方法一，调用 dynamic_rnn() 来让我们构建好的网络运行起来\n",
    "# ** 当 time_major==False 时， outputs.shape = [batch_size, timestep_size, hidden_size] \n",
    "# ** 所以，可以取 h_state = outputs[:, -1, :] 作为最后输出\n",
    "# ** state.shape = [layer_num, 2, batch_size, hidden_size], \n",
    "# ** 或者，可以取 h_state = state[-1][1] 作为最后输出\n",
    "# ** 最后输出维度是 [batch_size, hidden_size]\n",
    "# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)\n",
    "# h_state = state[-1][1]\n",
    "\n",
    "# *************** 为了更好的理解 LSTM 工作原理，我们把上面 步骤6 中的函数自己来实现 ***************\n",
    "# 通过查看文档你会发现， RNNCell 都提供了一个 __call__()函数，我们可以用它来展开实现LSTM按时间步迭代。\n",
    "# **步骤6：方法二，按时间步展开计算\n",
    "outputs = list()\n",
    "state = init_state\n",
    "with tf.variable_scope('RNN'):\n",
    "    for timestep in range(timestep_size):\n",
    "        if timestep > 0:\n",
    "            tf.get_variable_scope().reuse_variables()\n",
    "        # 这里的state保存了每一层 LSTM 的状态\n",
    "        (cell_output, state) = mlstm_cell(X[:, timestep, :],state)\n",
    "        outputs.append(cell_output)\n",
    "h_state = outputs[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 三、最后设置 loss function 和 优化器，展开训练并完成测试 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter0, step 200, training accuracy 0.90625\n",
      "Iter0, step 400, training accuracy 0.96875\n",
      "Iter1, step 600, training accuracy 0.96875\n",
      "Iter1, step 800, training accuracy 0.976562\n",
      "Iter2, step 1000, training accuracy 0.96875\n",
      "Iter2, step 1200, training accuracy 0.984375\n",
      "Iter3, step 1400, training accuracy 0.984375\n",
      "Iter3, step 1600, training accuracy 0.976562\n",
      "Iter4, step 1800, training accuracy 0.984375\n",
      "Iter4, step 2000, training accuracy 0.984375\n",
      "test accuracy 0.984\n"
     ]
    }
   ],
   "source": [
    "############################################################################\n",
    "# 以下部分其实和之前写的多层 CNNs 来实现 MNIST 分类是一样的。\n",
    "# 只是在测试的时候也要设置一样的 batch_size.\n",
    "\n",
    "# 上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor，我们要分类的话，还需要接一个 softmax 层\n",
    "# 首先定义 softmax 的连接权重矩阵和偏置\n",
    "# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')\n",
    "# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')\n",
    "# 开始训练和测试\n",
    "W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)\n",
    "bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)\n",
    "y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)\n",
    "\n",
    "\n",
    "# 损失和评估函数\n",
    "cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))\n",
    "train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "for i in range(2000):\n",
    "    _batch_size = 128\n",
    "    batch = mnist.train.next_batch(_batch_size)\n",
    "    if (i+1)%200 == 0:\n",
    "        train_accuracy = sess.run(accuracy, feed_dict={\n",
    "            _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})\n",
    "        # 已经迭代完成的 epoch 数: mnist.train.epochs_completed\n",
    "        print \"Iter%d, step %d, training accuracy %g\" % ( mnist.train.epochs_completed, (i+1), train_accuracy)\n",
    "    sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})\n",
    "\n",
    "# 计算测试数据的准确率\n",
    "print \"test accuracy %g\"% sess.run(accuracy, feed_dict={\n",
    "    _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们一共只迭代不到5个epoch，在测试集上就已经达到了0.98的准确率，可以看出来 LSTM 在做这个字符分类的任务上还是比较有效的，而且我们最后一次性对 10000 张测试图片进行预测，才占了 725 MiB 的显存。而我们在之前的两层 CNNs 网络中，预测 10000 张图片一共用了 8721 MiB 的显存，差了整整 12 倍呀！！ 这主要是因为 RNN/LSTM 网络中，每个时间步所用的权值矩阵都是共享的，可以通过前面介绍的 LSTM 的网络结构分析一下，整个网络的参数非常少。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、可视化看看 LSTM 的是怎么做分类的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毕竟 LSTM 更多的是用来做时序相关的问题，要么是文本，要么是序列预测之类的，所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里，我想通过可视化的方式来帮助大家理解 LSTM 是怎么样一步一步地把图片正确的给分类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (2, 2, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "# 手写的结果 shape\n",
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print X_batch.shape, y_batch.shape\n",
    "_outputs, _state = np.array(sess.run([outputs, state],feed_dict={\n",
    "            _X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size}))\n",
    "print '_outputs.shape =', np.asarray(_outputs).shape\n",
    "print 'arr_state.shape =', np.asarray(_state).shape\n",
    "# 可见 outputs.shape = [ batch_size, timestep_size, hidden_size]\n",
    "# state.shape = [layer_num, 2, batch_size, hidden_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (2, 2, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print X_batch.shape, y_batch.shape\n",
    "_outputs, _state = sess.run([outputs, state],feed_dict={\n",
    "            _X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size})\n",
    "print '_outputs.shape =', np.asarray(_outputs).shape\n",
    "print 'arr_state.shape =', np.asarray(_state).shape\n",
    "# 可见 outputs.shape = [ batch_size, timestep_size, hidden_size]\n",
    "# state.shape = [layer_num, 2, batch_size, hidden_size]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "看下面我找了一个字符 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]\n",
      " [ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]\n",
      " [ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.]\n",
      " [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]\n",
      " [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]\n",
      " [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]\n",
      " [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]\n",
      " [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]\n",
      " [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]\n",
      " [ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.]]\n"
     ]
    }
   ],
   "source": [
    "print mnist.train.labels[10:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们先来看看这个字符样子,上半部分还挺像 2 来的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADdBJREFUeJzt3X+MFPUZx/HPIz+CoZhIK4SAlGJME3IxYE44I5LW1kax\nBtFE6x/mGovXGNrYpNGqNdakMUFiS/yrEVICNK2lwV+kGmshtdZECD8CgijFNjSAyGlQa43JFe7p\nHzu0V739zt7u7M5wz/uVbG53np2dJxs+zOx+Z+dr7i4A8ZxTdgMAykH4gaAIPxAU4QeCIvxAUIQf\nCIrwA0ERfiAowg8ENbaTGzMzTicE2szdrZHntbTnN7NrzOygmb1lZve28loAOsuaPbffzMZI+quk\nqyUdlbRD0q3ufiCxDnt+oM06seefL+ktd/+7uw9I+q2kJS28HoAOaiX80yUdGfL4aLbs/5hZn5nt\nNLOdLWwLQMHa/oWfu6+WtFrisB+oklb2/MckXTjk8YxsGYCzQCvh3yHpYjP7kpmNl/QtSZuLaQtA\nuzV92O/up8zse5L+IGmMpLXu/nphnQFoq6aH+praGJ/5gbbryEk+AM5ehB8IivADQRF+ICjCDwRF\n+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E\n1dEpujG8ZcuWtVRfsGBB3dqrr76aXHfHjh3J+pYtW5L1Q4cOJetvvvlmso7ysOcHgiL8QFCEHwiK\n8ANBEX4gKMIPBEX4gaBamqXXzA5L+kjSaUmn3L075/khZ+ndtGlTsr5kyZJkfcyYMUW2U6iBgYFk\nfeXKlXVrDz74YNHtQI3P0lvEST5fdff3CngdAB3EYT8QVKvhd0lbzGyXmfUV0RCAzmj1sH+hux8z\nsymS/mhmb7r7y0OfkP2nwH8MQMW0tOd392PZ335JT0uaP8xzVrt7d96XgQA6q+nwm9lEM5t05r6k\nb0jaX1RjANqrlcP+qZKeNrMzr/Mbd3+hkK4AtF1L4/wj3tgoHeefPn16sr53795kfdKkScn6Y489\nlqz39/fXrfX09CTXnTVrVrLe1dWVrI8fPz5ZHxwcrFu7++67k+uuWrUqWcfwGh3nZ6gPCIrwA0ER\nfiAowg8ERfiBoAg/EBRDfQWYMGFCsr58+fJk/ZVXXknWt2/fPuKeirJ06dJkPfWTXUm66KKL6tZ2\n796dXLe7m5NCm8FQH4Akwg8ERfiBoAg/EBThB4Ii/EBQhB8IinF+tOS+++5L1h9++OG6tf3709d+\nueSSS5rqKTrG+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUEXM0otR7IILLkjWb7zxxqZfO+/3/Ggv\n9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTuOL+ZrZX0TUn97t6VLZssaaOkWZIOS7rZ3d9vX5to\nl0svvTRZX7NmTbI+b968ZH1gYKBubdOmTcl10V6N7PnXSbrmU8vulbTV3S+WtDV7DOAskht+d39Z\n0slPLV4iaX12f72kGwruC0CbNfuZf6q7H8/uvyNpakH9AOiQls/td3dPXZvPzPok9bW6HQDFanbP\nf8LMpklS9re/3hPdfbW7d7s7sy4CFdJs+DdL6s3u90p6tph2AHRKbvjN7AlJr0r6spkdNbPvSFoh\n6WozOyTp69ljAGcRrtt/FpgxY0ay3tXVVbfW09OTXPeee+5J1idMmJCs53nuuefq1q6//vqWXhvD\n47r9AJIIPxAU4QeCIvxAUIQfCIrwA0Fx6e4OmDNnTrJ+5513JuvXXnttsj579uwR99Sojz/+OFl/\n4IEHkvW8nwSjPOz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoxvk7YMOGDcl63uWzy3TgwIFkfdKk\nScl66hyEDz/8MLnukSNHknW0hj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFpbs74KWXXkrWFy1a\n1JlGKub999Ozum/bti1Zv+OOO5L1t99+e8Q9jQZcuhtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBJU7\nzm9mayV9U1K/u3dlyx6SdIekd7On3e/uz+duLOg4/7nnnpusz5s3r23bvuWWW5L1KVOmJOuXXXZZ\nS9ufOXNm3drYsa1dTuLgwYPJ+uWXX1639sEHH7S07Sorcpx/naRrhlm+yt3nZrfc4AOoltzwu/vL\nkk52oBcAHdTKZ/7vm9lrZrbWzM4vrCMAHdFs+H8habakuZKOS/pZvSeaWZ+Z7TSznU1uC0AbNBV+\ndz/h7qfdfVDSGknzE89d7e7d7t7dbJMAitdU+M1s2pCHSyXtL6YdAJ2SO9ZiZk9I+oqkL5jZUUk/\nkfQVM5srySUdlvTdNvYIoA34PT/a6rbbbqtbW7ZsWXLdK664Ilk/55z0gevmzZvr1pYuXZpct5O5\nKBq/5weQRPiBoAg/EBThB4Ii/EBQhB8IiqE+VNZdd92VrK9atarp177qqquS9bzLrVcZQ30Akgg/\nEBThB4Ii/EBQhB8IivADQRF+IKjWrp0MtNEzzzyTrK9cuTJZHzduXN3alVdemVz3bB7nbxR7fiAo\nwg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+Csgbz16xYkWyvm3btiLbqYzBwcFkvZVrUUyePLnpdUcL\n9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTuOL+ZXShpg6SpklzSand/zMwmS9ooaZakw5Judvf3\n29fq6PXuu+8m6y+88EKyvmvXrrq1Rx99NLnuiy++mKyfPn06WW/FzJkzk/Xbb789WR87tvnTVHp6\nepped7RoZM9/StIP3X2OpB5Jy81sjqR7JW1194slbc0eAzhL5Ibf3Y+7++7s/keS3pA0XdISSeuz\np62XdEO7mgRQvBF95jezWZLmSdouaaq7H89K76j2sQDAWaLhD01m9jlJT0r6gbv/0+x/04G5u9eb\nh8/M+iT1tdoogGI1tOc3s3GqBf/X7v5UtviEmU3L6tMk9Q+3rruvdvdud+8uomEAxcgNv9V28b+U\n9Ia7/3xIabOk3ux+r6Rni28PQLvkTtFtZgsl/UXSPklnfmN5v2qf+38naaakf6g21Hcy57WYonsY\nU6ZMSdYfeeSRZL23tzdZT9m3b1+yfurUqWT9+eefT9bnz59ft7ZgwYLkuuedd16ynmdgYKBuLW+o\nb8+ePS1tu0yNTtGd+5nf3V+RVO/FvjaSpgBUB2f4AUERfiAowg8ERfiBoAg/EBThB4LKHecvdGOM\n8zdlxowZyfrjjz9et7Zo0aLkuhMnTmyqpyr45JNPkvXly5fXra1bt67gbqqj0XF+9vxAUIQfCIrw\nA0ERfiAowg8ERfiBoAg/EBTj/KNc3rUCFi9enKzfdNNNyfp1112XrKcuDZ53yfKNGzcm63v37k3W\njxw5kqyPVozzA0gi/EBQhB8IivADQRF+ICjCDwRF+IGgGOcHRhnG+QEkEX4gKMIPBEX4gaAIPxAU\n4QeCIvxAULnhN7MLzexPZnbAzF43s7uy5Q+Z2TEz25Pd0j8MB1ApuSf5mNk0SdPcfbeZTZK0S9IN\nkm6W9C93f7ThjXGSD9B2jZ7kM7aBFzou6Xh2/yMze0PS9NbaA1C2EX3mN7NZkuZJ2p4t+r6ZvWZm\na83s/Drr9JnZTjPb2VKnAArV8Ln9ZvY5SX+W9LC7P2VmUyW9J8kl/VS1jwa357wGh/1AmzV62N9Q\n+M1snKTfS/qDu/98mPosSb93966c1yH8QJsV9sMeMzNJv5T0xtDgZ18EnrFU0v6RNgmgPI18279Q\n0l8k7ZM0mC2+X9Ktkuaqdth/WNJ3sy8HU6/Fnh9os0IP+4tC+IH24/f8AJIIPxAU4QeCIvxAUIQf\nCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQeVewLNg70n6x5DHX8iWVVFVe6tqXxK9\nNavI3r7Y6BM7+nv+z2zcbKe7d5fWQEJVe6tqXxK9Naus3jjsB4Ii/EBQZYd/dcnbT6lqb1XtS6K3\nZpXSW6mf+QGUp+w9P4CSlBJ+M7vGzA6a2Vtmdm8ZPdRjZofNbF8283CpU4xl06D1m9n+Icsmm9kf\nzexQ9nfYadJK6q0SMzcnZpYu9b2r2ozXHT/sN7Mxkv4q6WpJRyXtkHSrux/oaCN1mNlhSd3uXvqY\nsJktkvQvSRvOzIZkZislnXT3Fdl/nOe7+48q0ttDGuHMzW3qrd7M0t9Wie9dkTNeF6GMPf98SW+5\n+9/dfUDSbyUtKaGPynP3lyWd/NTiJZLWZ/fXq/aPp+Pq9FYJ7n7c3Xdn9z+SdGZm6VLfu0RfpSgj\n/NMlHRny+KiqNeW3S9piZrvMrK/sZoYxdcjMSO9ImlpmM8PInbm5kz41s3Rl3rtmZrwuGl/4fdZC\nd58r6VpJy7PD20ry2me2Kg3X/ELSbNWmcTsu6WdlNpPNLP2kpB+4+z+H1sp874bpq5T3rYzwH5N0\n4ZDHM7JlleDux7K//ZKeVu1jSpWcODNJava3v+R+/svdT7j7aXcflLRGJb532czST0r6tbs/lS0u\n/b0brq+y3rcywr9D0sVm9iUzGy/pW5I2l9DHZ5jZxOyLGJnZREnfUPVmH94sqTe73yvp2RJ7+T9V\nmbm53szSKvm9q9yM1+7e8Zukxap94/83ST8uo4c6fc2WtDe7vV52b5KeUO0w8N+qfTfyHUmfl7RV\n0iFJWyRNrlBvv1JtNufXVAvatJJ6W6jaIf1rkvZkt8Vlv3eJvkp53zjDDwiKL/yAoAg/EBThB4Ii\n/EBQhB8IivADQRF+ICjCDwT1H/peiEuJ5L9xAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f013a4d90d0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "X3 = mnist.train.images[13]\n",
    "img3 = X3.reshape([28, 28])\n",
    "plt.imshow(img3, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们看看在分类的时候，一行一行地输入，分为各个类别的概率会是什么样子的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 1, 256)\n",
      "(28, 256)\n"
     ]
    }
   ],
   "source": [
    "X3.shape = [-1, 784]\n",
    "y_batch = mnist.train.labels[0]\n",
    "y_batch.shape = [-1, class_num]\n",
    "\n",
    "X3_outputs = np.array(sess.run(outputs, feed_dict={\n",
    "            _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))\n",
    "print X3_outputs.shape\n",
    "X3_outputs.shape = [28, hidden_size]\n",
    "print X3_outputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABf5JREFUeJzt3Vtu4zYAQNG4K8kys7FZn/oxLaYFgiCxRIvkPec3mLFM\n0VfUI87jOI43ADr+unsDAHgt4QeIEX6AGOEHiBF+gBjhB4gRfoAY4QeIEX6AGOEHiBF+gJilw//+\n8et4//j16ZcNffWzs6959f+5IuMwjrEdZ1QXVjN9+EfF+9UHhTMHKRMVuNL04b+D0FoZwc6EP2yl\nuK+0rTA74edyz162OhN3l8rg+4SftJ0OCqMOnPy20xgJPwziLIRZCT9AjPDDQu64nHPm3zmzmfPs\nTviBp90V9xljuhLhB7ay0pnGXTfkhR/gba0DxlnCD3DSSl8t8/Ym/AA5wg8QI/wAMcIPECP8ADHC\nDxAj/AAxwg8QI/wAMcIPECP8ADHCDxDzOI7El9EB8A8rfoAY4QeIEX6AGOEHiBF+gBjh/0Tl727S\n2tel9/pqq42t8C9gtUkFzE34AWKEny04K7rH+8evw9ivZ9vwzzYhZ9seoGvb8HMfBziYm/AzjVFn\nRQ5E8H/Th9+HFv5wyZArTB/+EXx4gLJk+M8YdSnCgQh4FeG/kHgDVxm5IFw6/FbKMJ7P2H6WDj/n\nOHAy0sj5Zd6eI/wwiDgxK+EHiBF+XsoqGO4n/AAxws+n3PiFfQk/wFtrsSP8ADHCz1NevTLabTW2\n03thPcIPECP8wFZWOju8a1uFHzZxJiKrhJI/zuwz4YfJiPA4d6ywZ9yfwg/wDTMG/FmP49jmvQDw\nDVb8ADHCDxAj/AAxwg8QI/wAMcL/Qzs90nWGcRjH2O5lxv0p/AAxwg8QI/wAMcIPECP8QMZKX9k8\nkvADxAg/QIzwMw2n4fAawg9M+UtGjCP8ADHCDxAzffidggJca/rwA3At4Z+Ap1mAVxJ+gBjhD1vp\nTGOlbYXZCT9AjPCzPWcK8H/Cz6dcWoF9CT9AzNLh/2pV+uzPYFfm/Dirje3S4Qfg54T/Qqsd9Ucx\nDuxop3kt/DDIV6HYJSIum65J+D8xYiKf+YDs9MEa9V4qY3vHvaszY3vH/n72de84iJ3Zn2e29XEc\nS817AE6y4geIEX6AGOEHiBF+gBjhB4gR/k+s9ojfLjwTPpaxHWe1sRV+gBjh56VWWxnBjoQfIEb4\nAWKEHyBG+AFitg3/Xd+098rXA3jGtuEH4HPC/0NW9cDqhB8gRvgBYoQfIEb4ARZ05n6j8APECD8w\nhK/ZnpfwA8QIP0CM8JPmUgRFwg8QI/ywCTdT+S7h5ykCA+sS/gWILHAl4WcJLmPAdaYPvw8732Ge\njGNsfxsxDnctaB7HYZ8ClEy/4gfgWsIPECP8ADHCDxAj/AAxwg8QI/w/5JnmcfyS1pye3Scz7s8Z\nt+kOwg8QI/wwGStSRhN+iHBA4V/CDxAj/AAxwk+ayx8UCT9AjPADxAg/QIzwA8QI/wLcgASuJPwA\nMcIPECP8ADHCDxAj/GzBDXD4PuEHnuYPm6xJ+AFihB8gRvhhIS6tcAXhB4gR/hexUuO/zAXe3u7r\ngvBzOVGD80YeFJLht/o+b6XxW2lbzzCvf3t2HM6M3x3jfuY1H8eRnycAKckVP0CZ8APECD9AjPAD\nxAg/QIzwA8RsG37PNI9lbO9hXo9TGtttww/A54QfIEb4AWKEHyBG+AFihB8gRvgBYoQfIEb4AWKE\nHyBG+AFihB8gRvgBYoQfIEb4AWKEHyBG+GEylT8Gwn2EHyBG+AFihB8gRvgBYoQfIEb4uZynUmBu\nyfC/f/w6Xh2nr17zju0pMbbjmLtrehyHfQZQklzxA5QJP0CM8APECD9AjPADxAg/QEwy/J49HsvY\njmPujlMa22T4AcqEHyBG+AFihB8gRvgBYoQfIEb4AWKEHyBG+AFihB8gRvgBYoQfIEb4AWKEHyBG\n+AFihB8gRvgBYoQfIEb4AWKEHyBG+AFihB8gRvgBYoQfIOZxHMfd2wDAC1nxA8QIP0CM8APECD9A\njPADxAg/QIzwA8QIP0CM8APECD9AjPADxAg/QIzwA8QIP0CM8APECD9AjPADxAg/QIzwA8QIP0CM\n8APECD9AjPADxPwN14hzvy09OzkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f011b2a5610>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "h_W = sess.run(W, feed_dict={\n",
    "            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})\n",
    "h_bias = sess.run(bias, feed_dict={\n",
    "            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})\n",
    "h_bias.shape = [-1, 10]\n",
    "\n",
    "bar_index = range(class_num)\n",
    "for i in xrange(X3_outputs.shape[0]):\n",
    "    plt.subplot(7, 4, i+1)\n",
    "    X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size])\n",
    "    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias))\n",
    "    plt.bar(bar_index, pro[0], width=0.2 , align='center')\n",
    "    plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面的图中，为了更清楚地看到线条的变化，我把坐标都去了，每一行显示了 4 个图，共有 7 行，表示了一行一行读取过程中，模型对字符的识别。可以看到，在只看到前面的几行像素时，模型根本认不出来是什么字符，随着看到的像素越来越多，最后就基本确定了它是字符 3."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
